{
 "cells": [
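  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DivToMul Relay Transform\n",
    "\n",
    "We start with a Relay transform to get a first look at some of TVM's FFI mechanisms.\n",
    "\n",
    "Reading the source shows that `tvm/src/relay/transforms/` contains the implementations of a large number of Relay transforms. Here we pick the `DivToMul` pass in `tvm/src/relay/transforms/div_to_mul.cc` to see how a Relay transform takes effect.\n"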
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```c++\n",
    "namespace tvm {\n",
    "namespace relay {\n",
    "namespace transform {\n",
    "Pass DivToMul() {\n",
    "  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =\n",
    "      [=](Function f, IRModule m, PassContext pc) {\n",
    "        return Downcast<Function>(DivToMulRewrite().Mutate(f));\n",
    "      };\n",
    "  return CreateFunctionPass(pass_func, 0, \"DivToMul\", {\"InferType\", \"FoldConstant\"});\n",
    "}\n",
    "// Register the pass constructor as a global function\n",
    "TVM_REGISTER_GLOBAL(\"relay._transform.DivToMul\").set_body_typed(DivToMul);\n",
    "}\n",
    "}\n",
    "}\n",
    "```"
   ]
  },
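  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here the transform function {func}`DivToMul` is defined in the namespace `tvm::relay::transform` and registered as a global function under the name `relay._transform.DivToMul`. The rewriting itself is delegated to the `DivToMulRewrite` class:\n",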
    "\n",
    "````{dropdown}\n",
    "```c++\n",
    "namespace tvm {\n",
    "namespace relay {\n",
    "class DivToMulRewrite : public MixedModeMutator {\n",
    "  Expr Rewrite_(const CallNode* pre, const Expr& post) final {\n",
    "    if (const CallNode* call_node = post.as<CallNode>()) {\n",
    "      if (call_node->op == Op::Get(\"divide\")) {\n",
    "        auto rhs = call_node->args[1].as<ConstantNode>();\n",
    "        if (rhs != nullptr) {\n",
    "          auto inv =\n",
    "              runtime::NDArray::Empty(rhs->data.Shape(), rhs->data.DataType(), rhs->data->device);\n",
    "          std::string dtype = DLDataType2String(rhs->data.DataType());\n",
    "          if (dtype == \"float32\") {\n",
    "            float rhs_val = static_cast<float*>(rhs->data->data)[0];\n",
    "            // Check for division by zero\n",
    "            if (rhs_val == 0.) {\n",
    "              return post;\n",
    "            }\n",
    "            static_cast<float*>(inv->data)[0] = 1. / rhs_val;\n",
    "          } else if (dtype == \"float64\") {\n",
    "            double rhs_val = static_cast<double*>(rhs->data->data)[0];\n",
    "            // Check for division by zero\n",
    "            if (rhs_val == 0.) {\n",
    "              return post;\n",
    "            }\n",
    "            static_cast<double*>(inv->data)[0] = 1. / rhs_val;\n",
    "          } else if (dtype == \"float16\") {\n",
    "            // Do f16 math in f32\n",
    "            float rhs_val = __gnu_h2f_ieee(static_cast<uint16_t*>(rhs->data->data)[0]);\n",
    "            // Check for division by zero\n",
    "            if (rhs_val == 0.) {\n",
    "              return post;\n",
    "            }\n",
    "            static_cast<uint16_t*>(inv->data)[0] = __gnu_f2h_ieee(1. / rhs_val);\n",
    "          } else {\n",
    "            // Cannot do 1/int because it will truncate\n",
    "            return post;\n",
    "          }\n",
    "          return Multiply(call_node->args[0], Constant(inv));\n",
    "        }\n",
    "      }\n",
    "    }\n",
    "    return post;\n",
    "  }\n",
    "};\n",
    "}\n",
    "}\n",
    "```\n",
    "````\n"
   ]
  },
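  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the control flow of `Rewrite_` easier to follow, here is a minimal Python sketch of the same decision logic on a toy expression tree. It is not TVM code: `Var`, `Const`, `Call`, and `div_to_mul` are hypothetical names introduced only for illustration, constants are modelled as Python scalars, and the reciprocal is not rounded to the precision of the constant's dtype.\n",
    "\n",
    "```python\n",
    "from dataclasses import dataclass\n",
    "\n",
    "\n",
    "@dataclass(frozen=True)\n",
    "class Var:\n",
    "    name: str\n",
    "\n",
    "\n",
    "@dataclass(frozen=True)\n",
    "class Const:\n",
    "    value: object  # a Python int or float standing in for a scalar constant\n",
    "    dtype: str     # e.g. 'float32' or 'int32'\n",
    "\n",
    "\n",
    "@dataclass(frozen=True)\n",
    "class Call:\n",
    "    op: str        # 'divide', 'multiply', ...\n",
    "    args: tuple\n",
    "\n",
    "\n",
    "FLOAT_DTYPES = ('float16', 'float32', 'float64')\n",
    "\n",
    "\n",
    "def div_to_mul(expr):\n",
    "    # Leaves are returned unchanged.\n",
    "    if not isinstance(expr, Call):\n",
    "        return expr\n",
    "    # Rewrite the arguments first, so 'post' plays the role of the\n",
    "    # already-mutated call that Rewrite_ receives.\n",
    "    post = Call(expr.op, tuple(div_to_mul(a) for a in expr.args))\n",
    "    if post.op != 'divide':\n",
    "        return post\n",
    "    lhs, rhs = post.args\n",
    "    # Only a constant divisor can be inverted ahead of time.\n",
    "    if not isinstance(rhs, Const):\n",
    "        return post\n",
    "    # Integer reciprocals would truncate, so integer divisions are kept.\n",
    "    if rhs.dtype not in FLOAT_DTYPES:\n",
    "        return post\n",
    "    # Division by zero is left untouched.\n",
    "    if rhs.value == 0.0:\n",
    "        return post\n",
    "    return Call('multiply', (lhs, Const(1.0 / rhs.value, rhs.dtype)))\n",
    "\n",
    "\n",
    "x = Var('x')\n",
    "print(div_to_mul(Call('divide', (x, Const(4.0, 'float32')))))\n",
    "print(div_to_mul(Call('divide', (x, Const(2, 'int32')))))\n",
    "```\n",
    "\n",
    "The first call is rewritten into a `multiply` by the reciprocal, while the integer division comes back unchanged, mirroring the `return post` branches in the C++ code."
   ]
  },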
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use the pass from Python, a wrapper is defined in `tvm/python/tvm/relay/transform/transform.py`:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "def DivToMul():\n",
    "    \"\"\"Transform division by a constant to multiplication by the inverse of the constant\"\"\"\n",
    "    return _ffi_api.DivToMul()\n",
    "```"
   ]
  },
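  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The key is `_ffi_api`: the following call in `tvm/python/tvm/relay/transform/_ffi_api.py` exposes the global functions registered under the `relay._transform` prefix as attributes of that module, so `_ffi_api.DivToMul` resolves to the function registered on the C++ side:\n",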
    "\n",
    "```python\n",
    "tvm._ffi._init_api(\"relay._transform\", __name__)\n",
    "```"
   ]
  },
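  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The registration and lookup described above can be pictured with a toy model in pure Python. This is only a sketch of the idea, not TVM's implementation: `_GLOBAL_FUNCS`, `register_global`, and `init_api` are hypothetical stand-ins for the C++ global function table, `TVM_REGISTER_GLOBAL`, and `tvm._ffi._init_api`, and the registered function returns a placeholder string instead of a real pass object.\n",
    "\n",
    "```python\n",
    "import types\n",
    "\n",
    "# Hypothetical stand-in for the global function table kept on the C++ side.\n",
    "_GLOBAL_FUNCS = {}\n",
    "\n",
    "\n",
    "def register_global(name):\n",
    "    # Decorator that stores a function under a dotted global name.\n",
    "    def decorator(func):\n",
    "        _GLOBAL_FUNCS[name] = func\n",
    "        return func\n",
    "    return decorator\n",
    "\n",
    "\n",
    "def init_api(prefix, module):\n",
    "    # Attach every global named 'prefix.<suffix>' to the module as '<suffix>'.\n",
    "    for full_name, func in _GLOBAL_FUNCS.items():\n",
    "        if full_name.startswith(prefix + '.'):\n",
    "            setattr(module, full_name[len(prefix) + 1:], func)\n",
    "\n",
    "\n",
    "@register_global('relay._transform.DivToMul')\n",
    "def _make_div_to_mul():\n",
    "    # Placeholder for the pass object a real registration would return.\n",
    "    return 'DivToMul pass object'\n",
    "\n",
    "\n",
    "# A throwaway module object playing the role of _ffi_api.\n",
    "ffi_api = types.ModuleType('toy_ffi_api')\n",
    "init_api('relay._transform', ffi_api)\n",
    "print(ffi_api.DivToMul())\n",
    "```\n",
    "\n",
    "In this model the Python wrapper never needs to know where the function was defined. It only needs the prefix under which the function was registered."
   ]
  },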
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the Python-side test code, see [Division to multiplication](../../read/transforms/div-to-mul)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvmz",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
